{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Human numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#2) [Path('/home/jhoward/.fastai/data/human_numbers/train.txt'),Path('/home/jhoward/.fastai/data/human_numbers/valid.txt')]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.HUMAN_NUMBERS)\n",
    "path.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readnums(d):\n",
    "    \"Read the file `d` under `path` and return its stripped lines joined by ', '.\"\n",
    "    # `with` closes the file handle; the join consumes every line before the block exits\n",
    "    with open(path/d) as f: return ', '.join(o.strip() for o in f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_txt = readnums('train.txt'); train_txt[:80]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_txt = readnums('valid.txt'); valid_txt[-80:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_tok = tokenize1(train_txt)\n",
    "valid_tok = tokenize1(valid_txt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsets = Datasets([train_tok, valid_tok], tfms=Numericalize, dl_type=LMDataLoader, splits=[[0], [1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = dsets.dataloaders(bs=bs, val_bs=bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsets.show((dsets.train[0][0][:80],))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13017"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dsets.valid[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dls.valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(72, 3)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dls.seq_len, len(dls.valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.8248697916666665"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "13017/72/bs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(dls.valid)\n",
    "x1,y1 = next(it)\n",
    "x2,y2 = next(it)\n",
    "x3,y3 = next(it)\n",
    "it.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12992"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.numel()+x2.numel()+x3.numel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the closest multiple of 64 below 13017."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 72]), torch.Size([64, 72]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.shape,y1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 72]), torch.Size([64, 72]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2.shape,y2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2, 19, 11, 12,  9, 19, 11, 13,  9, 19, 11, 14,  9, 19, 11, 15,  9, 19,\n",
       "        11, 16,  9, 19, 11, 17,  9, 19, 11, 18,  9, 19, 11, 19,  9, 19, 11, 20,\n",
       "         9, 19, 11, 29,  9, 19, 11, 30,  9, 19, 11, 31,  9, 19, 11, 32,  9, 19,\n",
       "        11, 33,  9, 19, 11, 34,  9, 19, 11, 35,  9, 19, 11, 36,  9, 19, 11, 37],\n",
       "       device='cuda:5')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([19, 11, 12,  9, 19, 11, 13,  9, 19, 11, 14,  9, 19, 11, 15,  9, 19, 11,\n",
       "        16,  9, 19, 11, 17,  9, 19, 11, 18,  9, 19, 11, 19,  9, 19, 11, 20,  9,\n",
       "        19, 11, 29,  9, 19, 11, 30,  9, 19, 11, 31,  9, 19, 11, 32,  9, 19, 11,\n",
       "        33,  9, 19, 11, 34,  9, 19, 11, 35,  9, 19, 11, 36,  9, 19, 11, 37,  9],\n",
       "       device='cuda:5')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y1[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = dls.vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join([v[x] for x in x1[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen ,'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join([v[x] for x in y1[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "', eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two , eight thousand thirty three'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join([v[x] for x in x2[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "', eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join([v[x] for x in x3[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join([v[x] for x in x1[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand seventy four , eight'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join([v[x] for x in x2[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join([v[x] for x in x3[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'seven , nine thousand nine hundred eighty eight , nine thousand nine hundred eighty nine , nine thousand nine hundred ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join([v[x] for x in x3[-1]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single fully connected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = dsets.dataloaders(bs=bs, seq_len=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 3]), torch.Size([64, 3]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x,y = dls.one_batch()\n",
    "x.shape,y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "40"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nv = len(v); nv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nh=64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss4(input,target):\n",
    "    \"Cross-entropy between `input` and the last column of `target`.\"\n",
    "    last_col = target[:,-1]\n",
    "    return F.cross_entropy(input, last_col)\n",
    "\n",
    "def acc4(input,target):\n",
    "    \"`accuracy` of `input` measured against the last column of `target`.\"\n",
    "    last_col = target[:,-1]\n",
    "    return accuracy(input, last_col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model0(Module):\n",
    "    \"Unrolled model: reads up to the first three columns of `x` and predicts from the final hidden state.\"\n",
    "    def __init__(self):\n",
    "        self.i_h = nn.Embedding(nv,nh)  # input -> hidden  (green arrow)\n",
    "        self.h_h = nn.Linear(nh,nh)     # hidden -> hidden (brown arrow)\n",
    "        self.h_o = nn.Linear(nh,nv)     # hidden -> output (blue arrow)\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "\n",
    "    def _update(self, h):\n",
    "        \"Shared hidden-state update: linear, then ReLU, then batchnorm.\"\n",
    "        return self.bn(F.relu(self.h_h(h)))\n",
    "\n",
    "    def forward(self, x):\n",
    "        n_tok = x.shape[1]\n",
    "        # The first token starts the hidden state; later tokens are added to it before each update\n",
    "        h = self._update(self.i_h(x[:,0]))\n",
    "        if n_tok>1: h = self._update(h + self.i_h(x[:,1]))\n",
    "        if n_tok>2: h = self._update(h + self.i_h(x[:,2]))\n",
    "        return self.h_o(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, Model0(), loss_func=loss4, metrics=acc4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>acc4</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.459452</td>\n",
       "      <td>3.417839</td>\n",
       "      <td>0.144213</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.519120</td>\n",
       "      <td>2.569264</td>\n",
       "      <td>0.456250</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>2.031360</td>\n",
       "      <td>2.176257</td>\n",
       "      <td>0.459722</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.840601</td>\n",
       "      <td>2.040740</td>\n",
       "      <td>0.463657</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.772740</td>\n",
       "      <td>2.000901</td>\n",
       "      <td>0.463657</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.758649</td>\n",
       "      <td>1.995709</td>\n",
       "      <td>0.464120</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(6, 1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Same thing with a loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model1(Module):\n",
    "    \"Loop over every column of `x`, updating one hidden state, and predict from the final state.\"\n",
    "    def __init__(self):\n",
    "        self.i_h = nn.Embedding(nv,nh)  # green arrow\n",
    "        self.h_h = nn.Linear(nh,nh)     # brown arrow\n",
    "        self.h_o = nn.Linear(nh,nv)     # blue arrow\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Allocate the initial state directly on the input's device (no CPU tensor + copy per call)\n",
    "        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
    "        for i in range(x.shape[1]):\n",
    "            h = h + self.i_h(x[:,i])\n",
    "            h = self.bn(F.relu(self.h_h(h)))\n",
    "        return self.h_o(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, Model1(), loss_func=loss4, metrics=acc4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>acc4</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.445585</td>\n",
       "      <td>3.383623</td>\n",
       "      <td>0.194213</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.568218</td>\n",
       "      <td>2.707002</td>\n",
       "      <td>0.425694</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>2.063069</td>\n",
       "      <td>2.317326</td>\n",
       "      <td>0.460185</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.860497</td>\n",
       "      <td>2.152390</td>\n",
       "      <td>0.466667</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.787315</td>\n",
       "      <td>2.100394</td>\n",
       "      <td>0.467593</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.772113</td>\n",
       "      <td>2.092769</td>\n",
       "      <td>0.467593</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(6, 1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi fully connected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = dsets.dataloaders(bs=bs, seq_len=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 20]), torch.Size([64, 20]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x,y = dls.one_batch()\n",
    "x.shape,y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model2(Module):\n",
    "    \"Loop over every column of `x` and emit one prediction per column, stacked along dim 1.\"\n",
    "    def __init__(self):\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.h_h = nn.Linear(nh,nh)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Allocate the initial state directly on the input's device (no CPU tensor + copy per call)\n",
    "        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
    "        res = []\n",
    "        for i in range(x.shape[1]):\n",
    "            h = h + self.i_h(x[:,i])\n",
    "            h = F.relu(self.h_h(h))\n",
    "            # Batchnorm is applied only on the way to the output layer; the carried state is not normalized\n",
    "            res.append(self.h_o(self.bn(h)))\n",
    "        return torch.stack(res, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, Model2(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.736573</td>\n",
       "      <td>3.754480</td>\n",
       "      <td>0.063566</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>3.540261</td>\n",
       "      <td>3.523310</td>\n",
       "      <td>0.124826</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.300708</td>\n",
       "      <td>3.304820</td>\n",
       "      <td>0.248735</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>3.063205</td>\n",
       "      <td>3.128578</td>\n",
       "      <td>0.299777</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>2.861345</td>\n",
       "      <td>3.009128</td>\n",
       "      <td>0.335367</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>2.705495</td>\n",
       "      <td>2.929025</td>\n",
       "      <td>0.353894</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>2.593792</td>\n",
       "      <td>2.878335</td>\n",
       "      <td>0.367832</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>2.519732</td>\n",
       "      <td>2.850741</td>\n",
       "      <td>0.373140</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>2.475534</td>\n",
       "      <td>2.840007</td>\n",
       "      <td>0.375546</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>2.452727</td>\n",
       "      <td>2.838413</td>\n",
       "      <td>0.375918</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(10, 1e-4, pct_start=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Maintain state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model3(Module):\n",
    "    \"Per-column predictions with a hidden state that is kept on the module between calls.\"\n",
    "    def __init__(self):\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.h_h = nn.Linear(nh,nh)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "        # Plain attribute, so no device is hard-coded here; `forward` re-creates it on the input's device\n",
    "        self.h = torch.zeros(bs, nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        res = []\n",
    "        # Start from zeros again whenever the stored state does not match the input's batch size or device\n",
    "        if x.shape[0]!=self.h.shape[0] or x.device!=self.h.device:\n",
    "            self.h = torch.zeros(x.shape[0], nh, device=x.device)\n",
    "        h = self.h\n",
    "        for i in range(x.shape[1]):\n",
    "            h = h + self.i_h(x[:,i])\n",
    "            h = F.relu(self.h_h(h))\n",
    "            res.append(self.bn(h))\n",
    "        # Keep the values for the next call but cut them off from this call's autograd graph\n",
    "        self.h = h.detach()\n",
    "        res = torch.stack(res, dim=1)\n",
    "        res = self.h_o(res)\n",
    "        return res\n",
    "\n",
    "    def reset(self):\n",
    "        \"Zero the stored hidden state, keeping it on its current device.\"\n",
    "        self.h = torch.zeros(bs, nh, device=self.h.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, Model3(), metrics=accuracy, loss_func=CrossEntropyLossFlat())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.482397</td>\n",
       "      <td>3.442618</td>\n",
       "      <td>0.139980</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.828804</td>\n",
       "      <td>2.455908</td>\n",
       "      <td>0.417783</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>2.134592</td>\n",
       "      <td>2.153767</td>\n",
       "      <td>0.315203</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.763576</td>\n",
       "      <td>2.096672</td>\n",
       "      <td>0.316940</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.589015</td>\n",
       "      <td>2.090171</td>\n",
       "      <td>0.317088</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.497501</td>\n",
       "      <td>2.057994</td>\n",
       "      <td>0.331374</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.414305</td>\n",
       "      <td>1.895721</td>\n",
       "      <td>0.441195</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>1.307273</td>\n",
       "      <td>2.044791</td>\n",
       "      <td>0.437872</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>1.165429</td>\n",
       "      <td>1.991641</td>\n",
       "      <td>0.461210</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>1.033335</td>\n",
       "      <td>1.776033</td>\n",
       "      <td>0.542783</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.923316</td>\n",
       "      <td>1.810016</td>\n",
       "      <td>0.564509</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.834117</td>\n",
       "      <td>1.762270</td>\n",
       "      <td>0.565005</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.758906</td>\n",
       "      <td>1.723969</td>\n",
       "      <td>0.591940</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.699892</td>\n",
       "      <td>1.808163</td>\n",
       "      <td>0.578944</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.653839</td>\n",
       "      <td>1.802881</td>\n",
       "      <td>0.592039</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.620560</td>\n",
       "      <td>1.769326</td>\n",
       "      <td>0.614509</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.595637</td>\n",
       "      <td>1.782574</td>\n",
       "      <td>0.616667</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.578359</td>\n",
       "      <td>1.772477</td>\n",
       "      <td>0.623785</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.567210</td>\n",
       "      <td>1.772950</td>\n",
       "      <td>0.623115</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.561052</td>\n",
       "      <td>1.781880</td>\n",
       "      <td>0.621751</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(20, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## nn.RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model4(Module):\n",
    "    \"Embedding -> `nn.RNN` -> batchnorm -> linear, with the RNN hidden state kept on the module between calls.\"\n",
    "    def __init__(self):\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.rnn = nn.RNN(nh,nh, batch_first=True)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = BatchNorm1dFlat(nh)\n",
    "        # Plain attribute, so no device is hard-coded here; `forward` re-creates it on the input's device\n",
    "        self.h = torch.zeros(1, bs, nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Start from zeros again whenever the stored state does not match the input's batch size or device\n",
    "        if x.shape[0]!=self.h.shape[1] or x.device!=self.h.device:\n",
    "            self.h = torch.zeros(1, x.shape[0], nh, device=x.device)\n",
    "        res,h = self.rnn(self.i_h(x), self.h)\n",
    "        # Keep the values for the next call but cut them off from this call's autograd graph\n",
    "        self.h = h.detach()\n",
    "        return self.h_o(self.bn(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, Model4(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.462379</td>\n",
       "      <td>3.272240</td>\n",
       "      <td>0.265749</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.669984</td>\n",
       "      <td>2.254657</td>\n",
       "      <td>0.462872</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>2.026915</td>\n",
       "      <td>2.119816</td>\n",
       "      <td>0.315923</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.709504</td>\n",
       "      <td>2.164839</td>\n",
       "      <td>0.316964</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.538079</td>\n",
       "      <td>2.037120</td>\n",
       "      <td>0.388790</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.376378</td>\n",
       "      <td>2.241062</td>\n",
       "      <td>0.339459</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.182906</td>\n",
       "      <td>2.094107</td>\n",
       "      <td>0.371429</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>1.019852</td>\n",
       "      <td>1.614843</td>\n",
       "      <td>0.476141</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.871662</td>\n",
       "      <td>1.549297</td>\n",
       "      <td>0.486880</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.743875</td>\n",
       "      <td>1.525240</td>\n",
       "      <td>0.522867</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.636371</td>\n",
       "      <td>1.434942</td>\n",
       "      <td>0.558606</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.549575</td>\n",
       "      <td>1.398644</td>\n",
       "      <td>0.553646</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.480547</td>\n",
       "      <td>1.357781</td>\n",
       "      <td>0.564410</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.427223</td>\n",
       "      <td>1.290959</td>\n",
       "      <td>0.583606</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.388108</td>\n",
       "      <td>1.209717</td>\n",
       "      <td>0.606944</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.356891</td>\n",
       "      <td>1.256806</td>\n",
       "      <td>0.609722</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.332150</td>\n",
       "      <td>1.269009</td>\n",
       "      <td>0.610045</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.315104</td>\n",
       "      <td>1.244885</td>\n",
       "      <td>0.617956</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.304269</td>\n",
       "      <td>1.261909</td>\n",
       "      <td>0.615501</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.297769</td>\n",
       "      <td>1.279711</td>\n",
       "      <td>0.611533</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One-cycle training for 20 epochs, with 3e-3 passed as the learning-rate argument.\n",
    "learn.fit_one_cycle(20, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2-layer GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model5(Module):\n",
    "    \"Embedding -> 2-layer `nn.GRU` -> `BatchNorm1dFlat` -> linear head, carrying hidden state across batches.\"\n",
    "    def __init__(self):\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.rnn = nn.GRU(nh, nh, 2, batch_first=True)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = BatchNorm1dFlat(nh)\n",
    "        # Leading dim 2 matches the two GRU layers. Created on CPU so no GPU is needed to build the model;\n",
    "        # `forward` re-creates it on the input's device.\n",
    "        self.h = torch.zeros(2, bs, nh)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \"Map token ids `x` (batch first) to scores over the `nv` embedding entries, updating `self.h`.\"\n",
    "        # Start from zeros whenever the stored state cannot be reused: different batch size or different device.\n",
    "        if x.shape[0]!=self.h.shape[1] or self.h.device!=x.device:\n",
    "            self.h = torch.zeros(2, x.shape[0], nh, device=x.device)\n",
    "        res,h = self.rnn(self.i_h(x), self.h)\n",
    "        # Keep the values for the next batch but cut the graph, so backprop does not reach into earlier batches.\n",
    "        self.h = h.detach()\n",
    "        return self.h_o(self.bn(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind `learn` to a new Learner around a newly constructed Model5, with the same loss and metric arguments.\n",
    "learn = Learner(dls, Model5(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.666392</td>\n",
       "      <td>2.114901</td>\n",
       "      <td>0.497594</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.436292</td>\n",
       "      <td>1.357266</td>\n",
       "      <td>0.624330</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.678816</td>\n",
       "      <td>1.007875</td>\n",
       "      <td>0.745387</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.329509</td>\n",
       "      <td>0.735918</td>\n",
       "      <td>0.813219</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.168463</td>\n",
       "      <td>0.633921</td>\n",
       "      <td>0.837922</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.089841</td>\n",
       "      <td>0.612871</td>\n",
       "      <td>0.851290</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.051091</td>\n",
       "      <td>0.690696</td>\n",
       "      <td>0.840972</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.031449</td>\n",
       "      <td>0.706523</td>\n",
       "      <td>0.834896</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.020642</td>\n",
       "      <td>0.633427</td>\n",
       "      <td>0.843948</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.014271</td>\n",
       "      <td>0.636002</td>\n",
       "      <td>0.844072</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One-cycle training for 10 epochs, with 1e-2 passed as the learning-rate argument.\n",
    "learn.fit_one_cycle(10, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
